
## Partial negative log-likelihood.
##
## Evaluates the weighted Bernoulli negative log-likelihood of the outcome `y`
## given the observed (a0, l1, a1) pattern. No likelihood contribution for l1
## itself is included.
##
## Arguments:
##   par.update  values written into the parameter block named by `parn`
##               before the likelihood is evaluated.
##   parn        block to overwrite: "alpha1" (alpha[1,]), "alpha2345"
##               (alpha[2:5,]), "beta12" (beta[1:2,]), "beta3" (beta[3,]) or
##               "beta45" (beta[4:5,]). Any other value, such as the default
##               "not", leaves `pars` untouched.
##   pars        list holding the matrices `alpha` and `beta`.
##   data        object with columns "l0.1", "l0.2", "a0", "l1", "a1", "y"
##               (plus "l0.2.tilde" when gop.type is "incorrect"); presumably
##               a numeric matrix, since the l0 columns go straight into %*%.
##
## Read from the enclosing environment (not passed as arguments):
##   cnt                 per-row weights (presumably cell counts -- confirm).
##   gop.type            "correct" or "incorrect"; selects the covariates
##                       used for the third phi component.
##   vectorize.get.prob  whether GetProb() is called once on all rows or
##                       once per row.
##   GetProb, expit      helpers defined elsewhere.
##
## Returns a single number: minus the weighted sum of log-likelihood terms.
pnLL = function(par.update, parn = "not", pars, data){

    l0 = data[,c("l0.1","l0.2")]
    a0 = data[,"a0"];  l1 = data[,"l1"];  a1 = data[,"a1"];   y = data[,"y"]
    nb = nrow(data)

    if(parn == "alpha1")     pars$alpha[1,]   = par.update
    if(parn == "alpha2345")  pars$alpha[2:5,] = par.update
    if(parn == "beta12")     pars$beta[1:2,]  = par.update
    if(parn == "beta3")      pars$beta[3,]    = par.update
    if(parn == "beta45")     pars$beta[4:5,]  = par.update
    alpha = pars$alpha
    beta  = pars$beta

    ## Covariates for the third phi component depend on gop.type. Fail loudly
    ## on an unknown value instead of leaving phi3 undefined.
    if (gop.type == "correct"){
      l0.gop = l0
    }else if (gop.type == "incorrect"){
      l0.gop = data[,c("l0.1","l0.2.tilde")]
    }else{
      stop("unknown gop.type '", gop.type,
           "': expected \"correct\" or \"incorrect\"", call. = FALSE)
    }

    theta = exp(l0 %*% t(alpha))
    phi12 = exp(l0 %*% t(beta[1:2,]))
    phi3  = exp(l0.gop %*% beta[3,])
    phi45 = expit(l0 %*% t(beta[4:5,]))
    phi   = cbind(phi12, phi3, phi45)

    if (vectorize.get.prob){
      p.y.cond = GetProb(theta, phi)
    }else{
      p.y.cond = t(sapply(seq_len(nb), function(i) GetProb(theta[i,],
                                                           phi[i,])))
    }
    colnames(p.y.cond) = c("p000","p001","p010","p011","p100",
                           "p101","p110","p111")

    ## Column index of the observed (a0, l1, a1) pattern: 000 -> 1, ..., 111 -> 8.
    a0l1a1 = 4*a0 + 2*l1 + a1 + 1
    p.y = p.y.cond[cbind(seq_len(nb), a0l1a1)]

    ## A term whose multiplier is zero contributes nothing, even when the
    ## matching log is -Inf; without the guard 0 * -Inf gives NaN and poisons
    ## the whole sum.
    ll.y1 = ifelse(y == 0, 0, y * log(p.y))
    ll.y0 = ifelse(y == 1, 0, (1 - y) * log(1 - p.y))

    return(-sum(cnt * (ll.y0 + ll.y1)))
}


## Partial-likelihood estimation by cyclic block-wise optimisation.
##
## beta[4,] and beta[5,] are fitted first by weighted logistic regressions of
## l1 on l0 (without intercept) on the a0 == 1 and a0 == 0 rows respectively,
## and are then held fixed. The blocks alpha[1,], alpha[2:5,], beta[1:2,] and
## beta[3,] are updated in that order with L-BFGS-B, every coordinate bounded
## to [-10, 10]. Sweeps repeat until the largest relative squared change of
## any block is at most `thres` or `max.step` sweeps have been run.
##
## Arguments:
##   data      object with columns "l0.1", "l0.2", "a0", "l1" (further columns
##             are consumed by the likelihood helpers); presumably a numeric
##             matrix -- confirm against callers.
##   cnt       per-row weights handed to glm().
##   max.step  maximum number of sweeps; also used as optim's `maxit`.
##   thres     convergence threshold; also added to the denominator of the
##             relative change so that it cannot be zero.
##
## Returns abind(mle.est, nll, time, contrast, along = 4), where mle.est stacks
## alpha and beta along the third dimension and nll, time (elapsed seconds
## since the optimisation started) and contrast are each replicated to
## dim(mle.est).
##
## NOTE(review): the objective minimised is `nLL` (defined elsewhere), while
## the value reported is pnLL() -- confirm this is intentional.
## NOTE(review): pnLL() reads `cnt` from its own enclosing environment, not
## from this function's `cnt` argument -- confirm the two always agree.
PMLEst4 = function(data, cnt, max.step = 500, thres = 1e-4){

    ## Scalar bounds are recycled by optim() to the length of each block.
    Optim = function(par.update, fn, parn, pars, data){
        optim(par.update, fn, parn = parn, pars = pars, data = data,
              control = list(maxit = max.step), method = "L-BFGS-B",
              lower = -10, upper = 10)
    }

    pars = list(alpha = matrix(0, 5, 2), beta = matrix(0, 5, 2))

    ## First estimate phi[4:5].
    ## l0, l1, a0 and cnt must exist under exactly these names in this frame:
    ## glm() resolves the formula, `subset` and `weights` here.
    l0 = data[,c("l0.1","l0.2")]
    a0 = data[,"a0"];  l1 = data[,"l1"]
    pars$beta[4,] = glm(l1 ~ l0 - 1, subset = a0 == 1, family = "binomial",
                        weights = cnt)$coefficients
    pars$beta[5,] = glm(l1 ~ l0 - 1, subset = a0 == 0, family = "binomial",
                        weights = cnt)$coefficients

    ## Optimization
    ptm = proc.time()
    Diff = function(x,y) sum((x-y)^2)/sum(x^2+thres)

    ## Blocks in update order: likelihood label, matrix in `pars`, rows.
    blocks = list(list(parn = "alpha1",    mat = "alpha", rows = 1),
                  list(parn = "alpha2345", mat = "alpha", rows = 2:5),
                  list(parn = "beta12",    mat = "beta",  rows = 1:2),
                  list(parn = "beta3",     mat = "beta",  rows = 3))

    diff = thres + 1; step = 0
    while(diff > thres && step < max.step){
        if(step %% 10 == 0) cat("This is the ", step, "th step. Diff = ", diff, "\n")
        step = step + 1

        block.diff = numeric(length(blocks))
        for (k in seq_along(blocks)){
            blk     = blocks[[k]]
            current = pars[[blk$mat]][blk$rows, ]
            opt     = Optim(current, nLL, blk$parn, pars, data)
            block.diff[k] = Diff(opt$par, current)
            ## Later blocks in the same sweep see this updated value.
            pars[[blk$mat]][blk$rows, ] = opt$par
        }
        diff = max(block.diff)
    }

    alpha     = pars$alpha
    beta      = pars$beta
    mle.est   = abind(pars$alpha,pars$beta, along=3)

    nll  = pnLL(0, "not", pars, data)
    time = (proc.time() - ptm)[3]

    nll  = array(nll, dim(mle.est))
    time = array(time, dim(mle.est))

    contrast = GetContrast(data, alpha, beta)
    contrast = array(contrast, dim(mle.est))

    output = abind(mle.est, nll, time, contrast, along = 4)

    return(output)
}






